import torch
import numpy as np
import random
import os
from typing import Optional


def seed_everything(seed: int):
    """Seed every random number generator an experiment touches.

    Seeds Python's ``random`` module, records the seed in the
    ``PYTHONHASHSEED`` environment variable, seeds NumPy's global RNG and
    PyTorch's CPU / CUDA generators, and puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # NOTE(review): setting this at runtime presumably only influences child
    # processes, not this interpreter's own str hashing — confirm if relied on.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def average_eeg_trials(eeg_data):
    """
    Collapse repeated EEG trials by averaging over dimension 1.

    Expected input:  [batch_size, 4, 63, 250]  (4 repeated trials)
    Resulting shape: [batch_size, 63, 250]
    """
    trial_axis = 1
    return eeg_data.mean(dim=trial_axis)


def save_checkpoint(model, optimizer, epoch, loss, path):
    """Persist a training checkpoint to ``path`` with ``torch.save``.

    The saved mapping has the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``, in that order.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a torch optimizer).
        epoch: Epoch value stored verbatim.
        loss: Loss value stored verbatim.
        path: Destination accepted by ``torch.save``.
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    payload = {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
        'loss': loss,
    }
    torch.save(payload, path)


def load_checkpoint(model, optimizer, path, device):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        model: Object whose ``load_state_dict`` receives
            ``checkpoint['model_state_dict']``; it is modified in place.
        optimizer: Optimizer restored from ``checkpoint['optimizer_state_dict']``,
            or ``None`` to skip optimizer restoration.
        path: Checkpoint location passed to ``torch.load``.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        Tuple ``(model, optimizer, epoch, loss)``. The model and optimizer are
        the same objects passed in; epoch and loss are read from the checkpoint.

    Raises:
        KeyError: if a required key is missing from the checkpoint.
    """
    # NOTE(review): torch.load deserializes via pickle unless weights_only=True;
    # only load checkpoints from trusted sources.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state_dict'])
    return model, optimizer, state['epoch'], state['loss']


def compute_correlation(pred_eeg, true_eeg):
    """Compute the per-sample Pearson correlation between predicted and true EEG.

    Every sample is flattened over all non-batch dimensions
    (``flatten(start_dim=1)``) and correlated with its counterpart. This
    matches ``torch.corrcoef`` on each stacked pair, but is vectorized over
    the batch instead of looping in Python.

    Args:
        pred_eeg: Predicted EEG tensor with the batch on dim 0 (assumed
            real-valued).
        true_eeg: Ground-truth EEG tensor; must flatten to the same shape as
            ``pred_eeg``.

    Returns:
        1-D tensor of length ``batch_size`` with values in [-1, 1]. A sample
        whose prediction or target is constant has zero variance and yields
        NaN, as ``torch.corrcoef`` does. An empty batch yields an empty tensor.

    Raises:
        ValueError: if the two inputs do not flatten to the same shape.
    """
    pred_flat = pred_eeg.flatten(start_dim=1)
    true_flat = true_eeg.flatten(start_dim=1)
    if pred_flat.shape != true_flat.shape:
        # Without this check, mismatched batches were silently truncated or
        # raised IndexError, and broadcasting below could mask the error.
        raise ValueError(
            "pred_eeg and true_eeg must have matching shapes after flattening, "
            f"got {tuple(pred_flat.shape)} and {tuple(true_flat.shape)}"
        )

    # mean() is undefined for integer tensors; promote them to float.
    if not pred_flat.is_floating_point():
        pred_flat = pred_flat.float()
    if not true_flat.is_floating_point():
        true_flat = true_flat.float()

    # Pearson r = sum(dx * dy) / (||dx|| * ||dy||) on mean-centred rows.
    # The 1/(N-1) factors of covariance and variance cancel out.
    pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
    true_centered = true_flat - true_flat.mean(dim=1, keepdim=True)
    covariance = (pred_centered * true_centered).sum(dim=1)
    scale = pred_centered.norm(dim=1) * true_centered.norm(dim=1)

    # Clamp guards against rounding just outside [-1, 1] (as torch.corrcoef
    # does); NaN from zero-variance rows passes through unchanged.
    return (covariance / scale).clamp(-1.0, 1.0)


